#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Plotting functions for Sionna PHY"""

import numpy as np
import matplotlib.pyplot as plt
from itertools import compress
from sionna.phy.utils import sim_ber

def plot_ber(snr_db,
             ber,
             legend="",
             ylabel="BER",
             title="Bit Error Rate",
             ebno=True,
             is_bler=None,
             xlim=None,
             ylim=None,
             save_fig=False,
             path=""):
    """Plot error-rates

    Draws one or several error-rate curves on a semi-logarithmic y-axis in
    a newly created figure.

    Input
    -----
    snr_db: `numpy.ndarray` or `list` of `numpy.ndarray`
        Array defining the simulated SNR points. If ``ber`` is a `list`
        and ``snr_db`` is not, the same SNR points are used for all curves.

    ber: `numpy.ndarray` or `list` of `numpy.ndarray`
        Array defining the BER/BLER per SNR point. A `list` is interpreted
        as a list of curves.

    legend: `str`, (default ""), or `list` of `str`
        Legend entries

    ylabel: `str`, (default "BER")
        y-label

    title: `str`, (default "Bit Error Rate")
        Figure title

    ebno: `bool`, (default `True`)
        If `True`, the x-label is set to
        "EbNo [dB]" instead of "EsNo [dB]".

    is_bler: `None` (default) | `bool` | `list` of `bool`
        If `True`, the corresponding curve is dashed. A single `bool`
        applies to all curves. A `list` must contain one entry per curve.
        `None` is equivalent to `False` for all curves.

    xlim: `None` (default) | (`float`, `float`)
        x-axis limits

    ylim: `None` (default) | (`float`, `float`)
        y-axis limits

    save_fig: `bool`, (default `False`)
        If `True`, the figure is saved to ``path`` and closed afterwards.

    path: `str`, (default "")
        Path to save the figure (if ``save_fig`` is `True`)

    Output
    ------
    fig : `matplotlib.figure.Figure`
        Figure handle

    ax : matplotlib.axes.Axes
        Axes object
    """

    # legend must be a list or string
    if not isinstance(legend, list):
        assert isinstance(legend, str), "legend must be str or list of str."
        legend = [legend]

    assert isinstance(title, str), "title must be str."

    # Normalize the input to parallel lists with one entry per curve such
    # that single curves and lists of curves share the same code path.
    if isinstance(ber, list):
        bers = ber
        # broadcast snr if ber is list
        snrs = snr_db if isinstance(snr_db, list) else [snr_db] * len(ber)
    else:
        bers = [ber]
        snrs = [snr_db]
    num_curves = len(bers)

    # is_bler must provide exactly one flag per curve
    if is_bler is None:
        is_bler = [False] * num_curves
    elif isinstance(is_bler, list):
        assert len(is_bler) == num_curves, "is_bler has invalid size."
    else:
        assert isinstance(is_bler, bool), \
            "is_bler must be bool or list of bool."
        # A scalar flag applies to every curve. Wrapping it in a list of
        # length one would either be truthy regardless of its value or be
        # too short for multiple curves.
        is_bler = [is_bler] * num_curves

    fig, ax = plt.subplots(figsize=(16,10))

    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)

    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)

    plt.title(title, fontsize=25)

    for idx, (curve, dashed) in enumerate(zip(bers, is_bler)):
        # BLER curves are dashed; an empty format string keeps the
        # default line style
        line_style = "--" if dashed else ""
        plt.semilogy(snrs[idx], curve, line_style, linewidth=2)

    plt.grid(which="both")
    if ebno:
        plt.xlabel(r"$E_b/N_0$ (dB)", fontsize=25)
    else:
        plt.xlabel(r"$E_s/N_0$ (dB)", fontsize=25)
    plt.ylabel(ylabel, fontsize=25)
    plt.legend(legend, fontsize=20)
    if save_fig:
        plt.savefig(path)
        plt.close(fig)

    # return figure handle
    return fig, ax

class PlotBER():
    """Provides a plotting object to simulate and store BER/BLER curves

    Curves are stored internally via :meth:`add` or :meth:`simulate`.
    Calling the object plots all stored curves together with the curves
    that are optionally provided as inputs to the call. Curves provided to
    the call are not stored.

    Parameters
    ----------
    title: `str`, (default "Bit/Block Error Rate")
        Figure title

    Input
    -----
    snr_db: `numpy.ndarray` or `list` of `numpy.ndarray`, `float`
        SNR values

    ber: `numpy.ndarray` or `list` of `numpy.ndarray`, `float`
        BER values corresponding to ``snr_db``

    legend: `str` or `list` of `str`
        Legend entries

    is_bler: `bool` or `list` of `bool`, (default [])
        If `True`, ``ber`` will be interpreted as BLER. A single `bool`
        applies to all curves in ``ber``. If omitted, all curves in
        ``ber`` are interpreted as BER.

    show_ber: `bool`, (default `True`)
        If `True`, BER curves will be plotted.

    show_bler: `bool`, (default `True`)
        If `True`, BLER curves will be plotted.

    xlim: `None` (default) | (`float`, `float`)
        x-axis limits

    ylim: `None` (default) | (`float`, `float`)
        y-axis limits

    save_fig: `bool`, (default `False`)
        If `True`, the figure is saved to ``path``.

    path: `str`, (default "")
        Path to save the figure (if ``save_fig`` is `True`)
    """

    def __init__(self, title="Bit/Block Error Rate"):

        assert isinstance(title, str), "title must be str."
        self._title = title

        # Parallel lists: entry i of each list describes curve i
        self._bers = []
        self._snrs = []
        self._legends = []
        self._is_bler = []

    # The list defaults are only read and never modified in-place, so
    # sharing them between calls is harmless.
    # pylint: disable=W0102
    def __call__(self,
                 snr_db=[],
                 ber=[],
                 legend=[],
                 is_bler=[],
                 show_ber=True,
                 show_bler=True,
                 xlim=None,
                 ylim=None,
                 save_fig=False,
                 path=""):
        """Plot all stored curves and the curves provided as inputs.

        See the class docstring for a description of the inputs. The
        internally stored curves are not modified by this call.
        """

        assert isinstance(path, str), "path must be str"
        assert isinstance(save_fig, bool), "save_fig must be bool"

        # Curves provided with this call
        if isinstance(ber, list):
            new_bers = ber
            # broadcast snr if ber is list
            new_snrs = snr_db if isinstance(snr_db, list) \
                                            else [snr_db]*len(ber)
        else:
            new_bers = [ber]
            new_snrs = snr_db if isinstance(snr_db, list) else [snr_db]

        new_legends = legend if isinstance(legend, list) else [legend]

        # One BLER flag per provided curve: broadcast a scalar and
        # interpret a missing (empty) list as "all curves are BER"
        if not isinstance(is_bler, list):
            new_is_bler = [is_bler]*len(new_bers)
        elif len(is_bler)==0:
            new_is_bler = [False]*len(new_bers)
        else:
            new_is_bler = is_bler

        # Concatenation creates new lists; the stored curves are untouched
        snrs = self._snrs + new_snrs
        bers = self._bers + new_bers
        legends = self._legends + new_legends
        is_bler = self._is_bler + new_is_bler

        # deactivate BER/BLER
        snrs, bers, legends, is_bler = self._select_curves(
            snrs, bers, legends, is_bler, show_ber, show_bler)

        # set ylabel
        if not np.any(is_bler): # only BERs (or nothing) to plot
            ylabel = "BER"
        elif np.all(is_bler): # only BLERs to plot
            ylabel = "BLER"
        else:
            ylabel = "BER / BLER"

        # and plot the results
        plot_ber(snr_db=snrs,
                 ber=bers,
                 legend=legends,
                 is_bler=is_bler,
                 title=self._title,
                 ylabel=ylabel,
                 xlim=xlim,
                 ylim=ylim,
                 save_fig=save_fig,
                 path=path)

    @staticmethod
    def _select_curves(snrs, bers, legends, is_bler, show_ber, show_bler):
        """Drop BER and/or BLER curves from parallel lists of curves.

        BER curves are dropped if ``show_ber`` is `False` and BLER curves
        are dropped if ``show_bler`` is `False`. Returns the filtered
        ``snrs``, ``bers``, ``legends``, and ``is_bler`` as a tuple.
        """
        hide_ber = show_ber is False
        hide_bler = show_bler is False

        # nothing to do if the object is empty or nothing shall be hidden
        if len(is_bler)==0 or not (hide_ber or hide_bler):
            return snrs, bers, legends, is_bler

        # Plain Python mask; also works if no curve remains (np.invert
        # would fail on an empty list as it is converted to a float array)
        keep = [not (hide_bler if flag else hide_ber) for flag in is_bler]

        return (list(compress(snrs, keep)),
                list(compress(bers, keep)),
                list(compress(legends, keep)),
                list(compress(is_bler, keep)))

    def _append_curve(self, snr_db, ber, legend, is_bler):
        """Append a single curve to the internal lists."""
        self._bers += [ber]
        self._snrs += [snr_db]
        self._legends += [legend]
        self._is_bler += [is_bler]

    ####public methods
    @property
    def title(self):
        """
        `str` : Get/set title of the plot
        """
        return self._title

    @title.setter
    def title(self, title):
        assert isinstance(title, str), "title must be string"
        self._title = title

    @property
    def ber(self):
        """
        `list` of `numpy.ndarray`, `float` : Stored BER/BLER values
        """
        return self._bers

    @property
    def snr(self):
        """
        `list` of `numpy.ndarray`, `float` : Stored SNR values
        """
        return self._snrs

    @property
    def legend(self):
        """
        `list` of `str` : Legend entries
        """
        return self._legends

    @property
    def is_bler(self):
        """
        `list` of `bool` : Indicates if a curve shall be interpreted as BLER
        """
        return self._is_bler

    def simulate(self,
                 mc_fun,
                 ebno_dbs,
                 batch_size,
                 max_mc_iter,
                 legend="",
                 add_ber=True,
                 add_bler=False,
                 soft_estimates=False,
                 num_target_bit_errors=None,
                 num_target_block_errors=None,
                 target_ber=None,
                 target_bler=None,
                 early_stop=True,
                 graph_mode=None,
                 distribute=None,
                 add_results=True,
                 forward_keyboard_interrupt=True,
                 show_fig=True,
                 verbose=True):
        # pylint: disable=line-too-long
        r"""Simulate BER/BLER curves for a given model and saves the results

        Internally calls :class:`sionna.phy.utils.sim_ber`.

        Input
        -----
        mc_fun: `callable`
            Callable that yields the transmitted bits `b` and the
            receiver's estimate `b_hat` for a given ``batch_size`` and
            ``ebno_db``. If ``soft_estimates`` is `True`, b_hat is
            interpreted as logit.

        ebno_dbs: `numpy.ndarray` of `float`
            SNR points to be evaluated

        batch_size: `tf.int`
            Batch-size for evaluation

        max_mc_iter: `int`
            Max. number of Monte-Carlo iterations per SNR point

        legend: `str`, (default "")
            Name to appear in legend. The suffix " (BLER)" is appended
            for the BLER curve.

        add_ber: `bool`, (default `True`)
            Indicates if BER should be added to plot

        add_bler: `bool`, (default `False`)
            Indicate if BLER should be added to plot

        soft_estimates: `bool`, (default `False`)
            If `True`, ``b_hat`` is interpreted as logit and additional
            hard-decision is applied internally.

        num_target_bit_errors: `None` (default) | `int`
            Target number of bit errors per SNR point until the simulation
            stops

        num_target_block_errors: `None` (default) | `int`
            Target number of block errors per SNR point until the simulation
            stops

        target_ber: `None` (default) | `float`
            The simulation stops after the first SNR point
            which achieves a lower bit error rate as specified by
            ``target_ber``. This requires ``early_stop`` to be `True`.

        target_bler: `None` (default) | `float`
            The simulation stops after the first SNR point
            which achieves a lower block error rate as specified by
            ``target_bler``.  This requires ``early_stop`` to be `True`.

        early_stop: `bool`, (default `True`)
            If `True`, the simulation stops after the
            first error-free SNR point (i.e., no error occurred after
            ``max_mc_iter`` Monte-Carlo iterations).

        graph_mode: `None` (default) | "graph" | "xla"
            A string describing the execution mode of ``mc_fun``.
            Defaults to `None`. In this case, ``mc_fun`` is executed as is.

        distribute: `None` (default) | "all" | list of indices | `tf.distribute.strategy`
            Distributes simulation on multiple parallel devices. If `None`,
            multi-device simulations are deactivated. If "all", the workload
            will be automatically distributed across all available GPUs via the
            `tf.distribute.MirroredStrategy`.
            If an explicit list of indices is provided, only the GPUs with the
            given indices will be used. Alternatively, a custom
            `tf.distribute.strategy` can be provided. Note that the same
            `batch_size` will be used for all GPUs in parallel, but the number
            of Monte-Carlo iterations ``max_mc_iter`` will be scaled by the
            number of devices such that the same number of total samples is
            simulated. However, all stopping conditions are still in-place
            which can cause slight differences in the total number of simulated
            samples.

        add_results: `bool`, (default `True`)
            If `True`, the simulation results will be appended
            to the internal list of results. If `False`, the results are
            only shown in the figure (if ``show_fig`` is `True`) and
            not stored.

        show_fig: `bool`, (default `True`)
            If `True`, a BER figure will be plotted.

        verbose: `bool`, (default `True`)
            If `True`, the current progress will be printed.

        forward_keyboard_interrupt: `bool`, (default `True`)
            If `False`, `KeyboardInterrupts` will be
            caught internally and not forwarded (e.g., will not stop outer
            loops). If `True`, the simulation ends and returns the intermediate
            simulation results.

        Output
        ------
        ber: `tf.float`
            Simulated bit-error rates

        bler: `tf.float`
            Simulated block-error rates
        """

        ber, bler = sim_ber(
                        mc_fun,
                        ebno_dbs,
                        batch_size,
                        soft_estimates=soft_estimates,
                        max_mc_iter=max_mc_iter,
                        num_target_bit_errors=num_target_bit_errors,
                        num_target_block_errors=num_target_block_errors,
                        target_ber=target_ber,
                        target_bler=target_bler,
                        early_stop=early_stop,
                        graph_mode=graph_mode,
                        distribute=distribute,
                        verbose=verbose,
                        forward_keyboard_interrupt=forward_keyboard_interrupt)

        # The new curves are stored first such that they show up in the
        # figure; they are removed again below if add_results is False.
        num_added = 0
        if add_ber:
            self._append_curve(ebno_dbs, ber, legend, False)
            num_added += 1

        if add_bler:
            self._append_curve(ebno_dbs, bler, legend + " (BLER)", True)
            num_added += 1

        try:
            if show_fig:
                self()
        finally:
            # remove current curves if add_results=False, even if
            # plotting failed
            if add_results is False:
                for _ in range(num_added):
                    self.remove(-1)

        return ber, bler

    def add(self, ebno_db, ber, is_bler=False, legend=""):
        """Add static reference curves

        Input
        -----
        ebno_db: `numpy.ndarray` or `list` of `numpy.ndarray`, `float`
            SNR points

        ber: `numpy.ndarray` or `list` of `numpy.ndarray`, `float`
            BER  corresponding to each SNR point

        is_bler: `bool`, (default `False`)
            If `True`, ``ber`` is interpreted as BLER.

        legend: `str`, (default "")
            Legend entry
        """

        assert (len(ebno_db)==len(ber)), \
            "ebno_db and ber must have same number of elements."

        assert isinstance(legend, str), "legend must be str."
        assert isinstance(is_bler, bool), "is_bler must be bool."

        # concatenate curves
        self._append_curve(ebno_db, ber, legend, is_bler)

    def reset(self):
        """Removes all internal data"""
        self._bers = []
        self._snrs = []
        self._legends = []
        self._is_bler = []

    def remove(self, idx=-1):
        """Removes curve with index ``idx``

        Input
        ------
        idx: `int`, (default -1)
            Index of the dataset that should
            be removed. Negative indexing is possible.
        """

        assert isinstance(idx, int), "idx must be int."

        del self._bers[idx]
        del self._snrs[idx]
        del self._legends[idx]
        del self._is_bler[idx]

